Add rotary embedding onnx domain support#29261
Conversation
There was a problem hiding this comment.
Pull request overview
This PR updates the GroupQueryAttention fusion optimizer pass to recognize standard (ONNX-domain) RotaryEmbedding nodes when extracting the cos_cache/sin_cache inputs needed to fuse rotary embedding into the com.microsoft.GroupQueryAttention node.
Changes:
- Adds a helper to retrieve
cos_cache/sin_cacheNodeArgs for bothcom.microsoft.RotaryEmbeddingand ONNX-domainRotaryEmbedding(different input ordering). - Updates the fusion pattern-matching logic to use that helper and to require that rotary cache inputs were successfully identified before fusing.
|
cc @tianleiwu |
tianleiwu
left a comment
There was a problem hiding this comment.
Summary
The change extends GroupQueryAttentionFusion to also match the standard ONNX-domain RotaryEmbedding (X, cos_cache, sin_cache, position_ids) in addition to com.microsoft.RotaryEmbedding, plumbing position_ids through to GQA input 9 and setting the input-arg-count. The approach is sound:
- Requiring
position_idsto be present for the ONNX-domain path (and bailing otherwise) correctly avoids the 3D per-batch cos/sin cache form that GQA's 2D rotary cache validation cannot consume. - The
position_ids_arg_mismatchguard and the addedcos_cache_arg == nullptr || sin_cache_arg == nullptrchecks make the fusion safely skip ambiguous/mixed cases. MutableInputArgsCount()[9]is in-bounds because the GQA schema declares formal inputs up to index 11, soUpdateInputArgCount()sizes the vector accordingly.- Both the fused (with position_ids) and non-fused (omitted position_ids) cases are covered by new tests.
Main concern
Rotary interleaved / rotary_embedding_dim attributes are not validated or propagated. GQA's do_rotary path runs non-interleaved, full-width RoPE (rotary_interleaved defaults to 0 and the fusion never sets it). A standard ONNX RotaryEmbedding with interleaved=1, or a partial-rotary node with rotary_embedding_dim > 0 (and a correspondingly narrower cos/sin cache), is silently fused into a GQA that applies a different rotation — producing incorrect results with no error. Since this PR specifically targets a standard-ONNX export path where interleaved RoPE is common, the fusion should either verify interleaved == 0 and rotary_embedding_dim == 0 (full rotary) before matching, or propagate interleaved to GQA's rotary_interleaved. Inline comment below. (Note: the pre-existing com.microsoft.RotaryEmbedding path has the same latent gap.)
Minor
- Only
position_idsis checked for consistency between the two rotary nodes;cos_cache/sin_cacheare taken from whichever rotary node is visited first without verifying the second uses the same caches. In practice Q/K share caches, but an explicit equality guard (mirroring theposition_ids_arg_mismatchcheck) would make the fusion robust to malformed graphs.
tianleiwu
left a comment
There was a problem hiding this comment.
Thanks for updating the ONNX-domain RotaryEmbedding handling and adding the targeted fusion tests. I found one remaining correctness issue around explicit 2D position_ids: the fused GQA path can use different prompt-time position semantics than ONNX RotaryEmbedding, so I am requesting changes for that case.
| cos_cache_arg, | ||
| sin_cache_arg}; | ||
| if (position_ids_arg != nullptr) { | ||
| gqa_input_defs.push_back(position_ids_arg); |
There was a problem hiding this comment.
This forwards ONNX RotaryEmbedding position_ids into GQA, but GQA does not preserve the same prompt-time semantics. ONNX RotaryEmbedding treats a provided position_ids input as a full (batch_size, sequence_length) tensor and reads every token position. In the fused GQA path, prompt handling uses base-offset semantics: the CPU path sets position_ids_format = !parameters.is_first_prompt ? 1 : 0, and the CUDA paths similarly route through GQA rotary helpers instead of the ONNX op. So for first-prompt/prefill cases with non-contiguous or per-batch custom 2D positions, this fusion can silently rotate Q/K with different positions than the original ONNX nodes.
Can we either skip ONNX-domain fusion unless the position_ids are known to be contiguous/base-offset compatible, or update GQA RoPE to consume the full 2D position_ids tensor for prompt cases before enabling this rewrite?
There was a problem hiding this comment.
disabled the GQA fusion for ONNX-domain RotaryEmbedding for now
Description
Mobius exports standard ONNX rotary embedding op. Adding support for this.
Motivation and Context